
fixed bug, separated function #145

Merged
skyw merged 3 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/fix_riemann_sphere
Mar 23, 2026
Merged

fixed bug, separated function#145
skyw merged 3 commits intoNVIDIA-NeMo:mainfrom
mkhona-nvidia:mkhona/fix_riemann_sphere

Conversation

@mkhona-nvidia
Contributor

Address issue #136

Signed-off-by: mikail <mkhona@nvidia.com>
@mkhona-nvidia mkhona-nvidia requested a review from skyw March 23, 2026 16:13
@copy-pr-bot

copy-pr-bot bot commented Mar 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR fixes a critical bug in ObliqueSGD where the momentum buffer was never actually persisted across optimizer steps — buf = torch.add(...) created a new tensor each step without updating state["momentum_buffer"], so the optimizer silently ran with zero effective momentum. The fix correctly uses torch.add(..., out=buf) to update the buffer in-place. Alongside this, the helper function _compute_riemannian_grad_and_update is split into a pure _compute_riemannian_grad and inline weight-decay/update logic, enabling both ObliqueSGD and ObliqueAdam to integrate with the existing WeightDecayMixin and support configurable weight-decay methods (decoupled, independent, l2).
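For context, a minimal standalone sketch of the bug pattern (a toy state dict with hypothetical values, not the repo's actual code): torch.add without out= returns a fresh tensor, so rebinding the local buf leaves the stored buffer untouched, while out=buf mutates it in place.

```python
import torch

state = {"momentum_buffer": torch.zeros(2)}
grad = torch.tensor([1.0, 2.0])
momentum = 0.9

# Buggy pattern: torch.add returns a NEW tensor; the local name `buf` is
# rebound, but state["momentum_buffer"] still points at the old zero tensor.
buf = state["momentum_buffer"]
buf = torch.add(grad, buf, alpha=momentum)
assert torch.equal(state["momentum_buffer"], torch.zeros(2))  # never updated

# Fixed pattern: out=buf writes into the existing buffer in place, so the
# accumulated momentum persists across steps.
buf = state["momentum_buffer"]
torch.add(grad, buf, alpha=momentum, out=buf)
assert torch.equal(state["momentum_buffer"], grad)  # buffer now holds grad
```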

  • Bug fix is correct — the in-place out=buf write is the right approach; a new test (test_oblique_sgd_momentum_buffer_accumulates_across_steps) directly validates that the buffer accumulates across both steps.
  • Convergence threshold lowered from 50% to 40% — since working momentum should improve (not hurt) convergence, this reduction is suspicious: it may indicate that the high weight_decay=0.1 combined with now-active momentum causes instability in this test setup, or that the test was always flaky without a fixed seed.
  • weight_decay_method not in defaults dict — the parameter is stored only as self.weight_decay_method, so it is not preserved by state_dict()/load_state_dict() and cannot vary per param group, unlike the other optimizer hyperparameters (lr, momentum, weight_decay).

Confidence Score: 4/5

  • Safe to merge after confirming why the convergence accuracy threshold was lowered
  • The core bug fix is correct and well-tested. The refactoring is clean and the WeightDecayMixin integration is consistent with the rest of the codebase. One P1 concern remains: the convergence threshold reduction from 50% to 40% is counter-intuitive after fixing a bug that should improve convergence, and warrants a brief explanation or investigation before merge. The missing weight_decay_method in defaults is a follow-up style concern that does not block correctness for current usage.
  • tests/convergence/normalized_optimizer_test.py — the accuracy threshold reduction needs justification

Important Files Changed

  • emerging_optimizers/riemannian_optimizers/normalized_optimizer.py: Core bug fixed (momentum buffer was reassigned instead of updated in-place). Function refactored into _compute_riemannian_grad + WeightDecayMixin integration. Minor concern: weight_decay_method not in defaults dict, limiting per-group config and serialization.
  • tests/test_normalized_optimizer.py: New test correctly validates momentum buffer persistence across steps. Existing test parameterized over momentum values. Both changes are well-structured.
  • tests/convergence/normalized_optimizer_test.py: Convergence threshold lowered from 50% to 40% after the bug fix that enables real momentum; this is counter-intuitive, may hide a regression, and deserves investigation before merge.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant ObliqueSGD/ObliqueAdam
    participant WeightDecayMixin
    participant _compute_riemannian_grad

    Caller->>ObliqueSGD/ObliqueAdam: step()
    ObliqueSGD/ObliqueAdam->>ObliqueSGD/ObliqueAdam: update momentum buffer in-place<br/>torch.add(grad, buf, alpha=mom, out=buf)
    ObliqueSGD/ObliqueAdam->>_compute_riemannian_grad: (param, buf/norm_grad, dim)
    _compute_riemannian_grad-->>ObliqueSGD/ObliqueAdam: riem_grad (tangent-space projected)
    ObliqueSGD/ObliqueAdam->>WeightDecayMixin: _apply_weight_decay_inplace(param, riem_grad, lr, wd)
    Note over WeightDecayMixin: decoupled: param *= (1 - lr*wd)<br/>independent: param *= (1 - wd)<br/>l2: riem_grad += wd*param
    WeightDecayMixin-->>ObliqueSGD/ObliqueAdam: param or riem_grad updated in-place
    ObliqueSGD/ObliqueAdam->>ObliqueSGD/ObliqueAdam: param.add_(riem_grad, alpha=-lr)
    ObliqueSGD/ObliqueAdam->>ObliqueSGD/ObliqueAdam: normalize(param) — retract to manifold
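The three weight-decay branches in the diagram's note can be sketched as a standalone helper (hypothetical name and signature; the repo's WeightDecayMixin may differ in its actual API):

```python
import torch

def apply_weight_decay_inplace(param: torch.Tensor, grad: torch.Tensor,
                               lr: float, wd: float, method: str) -> None:
    """Sketch of the three weight-decay variants; mutates param or grad in place."""
    if method == "decoupled":
        param.mul_(1 - lr * wd)       # shrink weights by an lr-scaled factor
    elif method == "independent":
        param.mul_(1 - wd)            # shrink weights independently of lr
    elif method == "l2":
        grad.add_(param, alpha=wd)    # fold the L2 penalty into the gradient
    else:
        raise ValueError(f"unknown weight_decay_method: {method}")

p = torch.ones(3)
g = torch.zeros(3)
apply_weight_decay_inplace(p, g, lr=0.1, wd=0.5, method="decoupled")
# each element of p is now 1 - 0.1 * 0.5 = 0.95
```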

Comments Outside Diff (2)

  1. tests/convergence/normalized_optimizer_test.py, line 214 (link)

    Convergence bar reduced — possibly hiding a regression

    The bug fix makes momentum accumulation work correctly for the first time (the old buf = torch.add(...) discarded state every step, effectively running the optimizer with zero momentum). Enabling real momentum should improve or maintain convergence, so lowering the accuracy threshold from 50 % to 40 % is counter-intuitive and may signal an undetected regression.

    Possible explanations worth investigating before merge:

    • Weight-decay interaction: the old broken momentum silently compensated for the high weight_decay=0.1 used in all four test cases. With real momentum now active, the effective update magnitude is larger and may cause divergence on this particular test setup.
    • Test flakiness: no fixed seed is set in setUpModule for the convergence suite (seed is None by default), so results vary between runs.

    Consider adding a fixed seed for the convergence test and restoring the 50 % threshold (or explaining why 40 % is the correct expectation for a correctly-implemented Riemannian SGD with momentum).
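One way to pin the seed, as a sketch: the seed value and the unittest-style setUpModule hook are illustrative assumptions, not the suite's actual setup.

```python
import random
import torch

def setUpModule():
    # Pin all RNG sources so convergence runs are reproducible across
    # invocations; 1234 is an arbitrary illustrative choice.
    seed = 1234
    random.seed(seed)
    torch.manual_seed(seed)

# With the module hook in place, identical runs draw identical samples.
setUpModule()
a = torch.rand(4)
setUpModule()
b = torch.rand(4)
assert torch.equal(a, b)
```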

  2. emerging_optimizers/riemannian_optimizers/normalized_optimizer.py, line 72-80 (link)

    weight_decay_method missing from defaults dict

    weight_decay_method is stored as a plain instance attribute (self.weight_decay_method) but is not added to the defaults dict alongside lr, momentum, weight_decay, etc. This has two practical consequences:

    1. Serialization gap: optimizer.state_dict() / load_state_dict() saves and restores param_groups (which are built from defaults), but weight_decay_method is not in param_groups, so it is silently dropped on round-trip serialization. The same issue exists in ObliqueAdam (line ~175–183).

    2. Per-group configuration impossible: other hyperparameters can vary between param groups, but weight_decay_method cannot, because _apply_weight_decay_inplace reads self.weight_decay_method rather than a per-group value.

    Consider including it in defaults and reading it from the group dict in the step loop (consistent with how lr, momentum, etc. are handled), or at least documenting the current limitation. The same fix applies to ObliqueAdam.__init__ (line 183).

    defaults = dict(
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        weight_decay_method=weight_decay_method,
        dim=dim,
        eps=eps,
    )
    # self.weight_decay_method = weight_decay_method  ← can be removed
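To illustrate why keys placed in defaults survive serialization and become per-group, here is a toy optimizer (hypothetical ToyOpt, not the repo's class) built on torch.optim.Optimizer:

```python
import torch

class ToyOpt(torch.optim.Optimizer):
    """Toy optimizer showing that keys in `defaults` propagate to param_groups."""
    def __init__(self, params, lr=0.1, weight_decay=0.0,
                 weight_decay_method="decoupled"):
        defaults = dict(lr=lr, weight_decay=weight_decay,
                        weight_decay_method=weight_decay_method)
        super().__init__(params, defaults)

p = torch.nn.Parameter(torch.ones(2))
opt = ToyOpt([p], weight_decay_method="l2")

# The key now lives in each param group, so state_dict() round-trips it
# and each group could override it independently.
assert opt.param_groups[0]["weight_decay_method"] == "l2"
assert "weight_decay_method" in opt.state_dict()["param_groups"][0]
```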

Reviews (3): Last reviewed commit: "addressed failing test"

Comment on lines +178 to +183
torch.testing.assert_close(
optimizer.state[param]["momentum_buffer"],
expected_buffer,
atol=0,
rtol=0,
)
Contributor

P2 Strict zero-tolerance assertion may be fragile on second step

The second momentum-buffer assertion uses atol=0, rtol=0 against expected_buffer = second_grad + 0.8 * first_grad. The optimizer computes buf.mul_(0.8).add_(second_grad) while the expected value is computed as second_grad + (0.8 * first_grad) — two different Python/PyTorch expressions. Float32 addition is commutative (a + b == b + a) so the values are identical in this specific case, but the ordering of operations differs and could diverge on other hardware/precision modes.

Consider using a small tolerance to make the test more robust:

Suggested change
torch.testing.assert_close(
optimizer.state[param]["momentum_buffer"],
expected_buffer,
atol=0,
rtol=0,
)
torch.testing.assert_close(
optimizer.state[param]["momentum_buffer"],
expected_buffer,
atol=1e-6,
rtol=1e-6,
)
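The reviewer notes commutativity saves this exact case; as a general illustration of why bitwise-equality assertions across differently-ordered float computations are fragile, float addition is not associative:

```python
# Float addition is commutative but NOT associative: regrouping the same
# terms can change the rounded IEEE-754 result, so two expressions that
# compute "the same" sum in a different order need not be bit-identical.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
assert left != right            # differ in the last bit
assert abs(left - right) < 1e-12  # but well within any small tolerance
```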

Contributor

@skyw skyw left a comment

some minor things. otherwise LGTM

riem_grad = _compute_riemannian_grad(param, buf, dim)

# Apply the weight update
param.mul_(1 - lr * wd)
Contributor

Q: would addmm be enough for this?

Contributor Author

this step is done this way for literally every other optimizer in this repo

Contributor

Remind me the reason? I vaguely remember one of the weight decay type can't be done in single addmm

Contributor Author

I also forgot, but we made the opt_mixin for this

"""Test that ObliqueSGD persists momentum state across optimization steps."""
param = torch.tensor(
[[1.0, 0.0], [0.0, 1.0]],
dtype=torch.float32,
Contributor

nit: dtype=torch.float32 is not necessary. default dtype is almost never changed.

Signed-off-by: mikail <mkhona@nvidia.com>
skyw
skyw previously approved these changes Mar 23, 2026
@skyw
Contributor

skyw commented Mar 23, 2026

/ok to test ec6b492

Signed-off-by: mikail <mkhona@nvidia.com>
@mkhona-nvidia
Contributor Author

/ok to test 7091f0c

@skyw skyw merged commit 4864796 into NVIDIA-NeMo:main Mar 23, 2026
18 checks passed